/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-batch-entry sequence bookkeeping.
//
// For batch index `bidb`, this resolves how many Q and K rows the entry has
// and where its first Q / K row lives.
//   * When Varlen is true and the corresponding cu_seqlens_* pointer is
//     non-null, offsets and lengths come from those arrays, and rows are
//     addressed as start_row * row_stride.
//   * Otherwise the fixed params.seqlen_q / params.seqlen_k are used, and the
//     entry is addressed as bidb * batch_stride.
//
// sum_s_q / sum_s_k hold -1 when no start row is available. In that case
// q_offset / k_offset fall back to the batch stride.
template <bool Varlen = true>
struct BlockInfo {
    // NOTE(review): Params is assumed to expose cu_seqlens_q, cu_seqlens_k,
    // is_seqlens_k_cumulative, seqlen_q, seqlen_k, seqused_k, knew_ptr and
    // seqlen_knew. The exact field types are defined by the caller.
    template <typename Params>
    __device__ BlockInfo(const Params& params, const int bidb)
        // First Q row of this entry, or -1 if cu_seqlens_q is unused.
        : sum_s_q(Varlen && params.cu_seqlens_q != nullptr ? params.cu_seqlens_q[bidb] : -1),
          // First K row of this entry. This is only available when cu_seqlens_k
          // holds cumulative offsets; -1 otherwise.
          sum_s_k(Varlen && params.cu_seqlens_k != nullptr && params.is_seqlens_k_cumulative
                      ? params.cu_seqlens_k[bidb]
                      : -1),
          // Q row count: the difference of adjacent cumulative offsets, or the
          // fixed seqlen_q.
          actual_seqlen_q(Varlen && params.cu_seqlens_q != nullptr
                              ? params.cu_seqlens_q[bidb + 1] - sum_s_q
                              : params.seqlen_q),
          // K row count before any appended rows.
          //   * Cumulative cu_seqlens_k: cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
          //   * Non-cumulative: cu_seqlens_k[bidb] stores the length directly.
          //   * No cu_seqlens_k: the fixed seqlen_k.
          seqlen_k_cache(Varlen && params.cu_seqlens_k != nullptr
                             ? (params.is_seqlens_k_cumulative
                                    ? params.cu_seqlens_k[bidb + 1] - sum_s_k
                                    : params.cu_seqlens_k[bidb])
                             : params.seqlen_k),
          // Effective K row count.
          //   * If seqused_k is set, seqused_k[bidb] is used as-is.
          //   * Otherwise it is seqlen_k_cache, plus seqlen_knew when knew_ptr is
          //     non-null (presumably rows appended by this call; confirm with
          //     the caller).
          actual_seqlen_k(!params.seqused_k
                              ? seqlen_k_cache +
                                    (params.knew_ptr != nullptr ? params.seqlen_knew : 0)
                              : params.seqused_k[bidb]) {}

    // Offset of this entry's first Q row, in the units of the given strides.
    // Uses sum_s_q * row_stride when a start row is known, else bidb * batch_stride.
    template <typename index_t>
    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride,
                                       const int bidb) const {
        return sum_s_q != -1 ? uint32_t(sum_s_q) * row_stride : bidb * batch_stride;
    }

    // Offset of this entry's first K row, in the units of the given strides.
    // Uses sum_s_k * row_stride when a start row is known, else bidb * batch_stride.
    template <typename index_t>
    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride,
                                       const int bidb) const {
        return sum_s_k != -1 ? uint32_t(sum_s_k) * row_stride : bidb * batch_stride;
    }

    // Members are initialized in declaration order, not in the order written in
    // the constructor's initializer list. Each member below may only depend on
    // members declared above it:
    //   * actual_seqlen_q reads sum_s_q.
    //   * seqlen_k_cache reads sum_s_k.
    //   * actual_seqlen_k reads seqlen_k_cache.
    // Do not reorder these declarations.
    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    const int seqlen_k_cache;
    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}    // namespace flash
